{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## KL Divergence\n",
    "\n",
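    "KL Divergence, or Kullback-Leibler Divergence, measures how much one probability distribution differs from another. For two discrete distributions $p_X$ and $p_Y$ defined over the same domain $\\mathcal{D}$ it is given by the sum below (for continuous distributions, the sum becomes an integral over the probability density functions). \n",
    "$$ D(X, Y) = \\sum_{x\\in \\mathcal{D}} p_{X}(x) \\log\\Big(\\frac{p_{X}(x)}{p_{Y}(x)}\\Big) $$\n",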
    "\n",
    "Some properties of KL Divergence:\n",
    "\n",
    "- $ D(X, X) = 0$, KLD between a distribution and itself is 0\n",
    "- $ D(X, Y) \\ge 0$, KLD is never negative (it is 0 only when the two distributions are identical)\n",
    "- $ D(X, Y) \\ne D(Y, X)$ in general, KLD is not symmetric\n",
    "\n",
    "Let's see some of these through PyTorch code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from torch.distributions import Normal, Uniform\n",
    "from torch.distributions import kl_divergence\n",
    "\n",
    "%matplotlib widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KLD between p and p: 0.0\n",
      "KLD between p and q: 0.3181471824645996\n",
      "KLD between q and p: 0.8068528175354004\n",
      "KLD between p and r: 0.9175443649291992\n"
     ]
    }
   ],
   "source": [
    "# Let us consider 3 normal distributions\n",
    "p = Normal(0, 5)\n",
    "q = Normal(0, 10)\n",
    "r = Normal(0, 20)\n",
    "\n",
    "kld_p_p = kl_divergence(p, p)\n",
    "kld_p_q = kl_divergence(p, q)\n",
    "kld_q_p = kl_divergence(q, p)\n",
    "kld_p_r = kl_divergence(p, r)\n",
    "\n",
    "print('KLD between p and p: {}'.format(kld_p_p))\n",
    "print('KLD between p and q: {}'.format(kld_p_q))\n",
    "print('KLD between q and p: {}'.format(kld_q_p))\n",
    "print('KLD between p and r: {}'.format(kld_p_r))\n",
    "\n",
    "# As expected, the KL divergence between the 2 identical distributions is 0\n",
    "assert kld_p_p == 0\n",
    "\n",
    "# Note that KL divergence is not a distance, which means it is not symmetric\n",
    "assert kld_p_q != kld_q_p\n",
    "\n",
    "# KL Divergence between p and q is less than that between p and r.\n",
    "assert kld_p_q < kld_p_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kld_plot(p, q, r):\n",
    "    \"\"\"Plot the densities of p, q and r, with KLD(p, q) and KLD(p, r) in the title.\n",
    "\n",
    "    Each argument is a distribution object exposing `log_prob`; the pairs\n",
    "    (p, q) and (p, r) are also passed to `kl_divergence`. The densities are\n",
    "    evaluated on 1000 evenly spaced points in [-100, 100].\n",
    "    \"\"\"\n",
    "    x = np.linspace(-100, 100, 1000)\n",
    "    grid = torch.from_numpy(x)  # convert once and reuse for all three curves\n",
    "\n",
    "    # exp(log_prob) turns the log-densities back into densities\n",
    "    curves = [\n",
    "        (dist.log_prob(grid).exp(), colour, legend_label)\n",
    "        for dist, colour, legend_label in (\n",
    "            (p, 'green', \"p distribution\"),\n",
    "            (q, 'blue', \"q distribution\"),\n",
    "            (r, 'orange', \"r distribution\"),\n",
    "        )\n",
    "    ]\n",
    "\n",
    "    # Both divergences are computed before the figure is created\n",
    "    divergence_pq = kl_divergence(p, q)\n",
    "    divergence_pr = kl_divergence(p, r)\n",
    "\n",
    "    _, ax = plt.subplots()\n",
    "    ax.set_title(\"KLD(p, q) = {:.2f}. KLD(p, r) = {:.2f}\".format(divergence_pq, divergence_pr))\n",
    "    ax.set_ylabel(\"P(X)\")\n",
    "    ax.set_xlabel(\"X\")\n",
    "    for density, colour, legend_label in curves:\n",
    "        ax.plot(x, density, colour, label=legend_label)\n",
    "    ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2704371f272d49019d2842aa52f70f6b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# As expected, KLD between p and q is less than KLD between p and r.\n",
    "kld_plot(p, q, r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "acdfeb3f9dd64336b0c8787ed54fe684",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# KL Divergence works for distributions that are not Normal as well.\n",
    "# Consider the following scenario where p is a Uniform distribution whereas q and r are Normal.\n",
    "\n",
    "p = Uniform(-20, 20, validate_args=False)\n",
    "q = Normal(0, 20)\n",
    "r = Normal(-50, 20)\n",
    "\n",
    "# As expected, KLD(p, q) < KLD(p, r): q is centred on p's range, while r is shifted to -50\n",
    "kld_plot(p, q, r)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
